package websocket

import (
	"errors"
	"github.com/gorilla/websocket"
	uuid "github.com/satori/go.uuid"
	"net/http"
	"sync"
	"time"
)

// upgrade turns incoming HTTP requests into WebSocket connections using
// 1 KiB read and write buffers.
//
// NOTE(review): CheckOrigin accepts every Origin header, which disables the
// same-origin check gorilla/websocket performs by default and permits
// cross-site WebSocket hijacking. Restrict it if the endpoint relies on
// cookies or other ambient credentials.
var upgrade = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin:     func(*http.Request) bool { return true },
}

// IEvent receives the lifecycle callbacks a Gateway emits for each client.
type IEvent interface {
	// OnConnect is called by Run after a new connection has been registered.
	OnConnect(clientId string)
	// OnMessage is called by Run for every message read from the client; only
	// the payload is passed on, not the WebSocket message type.
	OnMessage(clientId string, data []byte)
	// OnClose is called by CloseClient after the connection has been closed
	// and unregistered.
	OnClose(clientId string)
}

// Gateway tracks live WebSocket connections by client id, by bound uid and
// by group membership.
type Gateway struct {
	event  IEvent
	conns  map[string]*Connection // clientId -> connection
	users  map[string]*Connection // uid bound via BindUid -> connection
	groups map[string]*Group      // groupId -> group
	mutex  sync.RWMutex           // locked around access to conns, users and groups
}

// Connection is a single client's WebSocket session.
type Connection struct {
	clientId  string
	uid       string                 // uid set by BindUid; empty until bound
	wsConn    *websocket.Conn        // underlying WebSocket connection
	readChan  chan *WsMessage        // messages passed from readLoop to Run
	closeChan chan byte              // close signal; see SendCloseSign
	data      map[string]interface{} // session data bound to this connection (cf. GatewayWorker sessions)
	once      sync.Once              // makes SendCloseSign signal only once
	groups    map[string]Nil         // ids of the groups this connection has joined
}

// Nil is a zero-size placeholder value used to build string sets.
type Nil struct{}

// Group is a named set of connections keyed by client id.
// mutex is meant to protect conns (JoinGroup locks it while adding members).
type Group struct {
	mutex sync.RWMutex
	conns map[string]*Connection
}

// WsMessage is one message read from a client by readLoop.
type WsMessage struct {
	Type    int    // WebSocket message type as returned by ReadMessage
	Message []byte // message payload
}

// New creates an empty Gateway that reports connection lifecycle events to event.
func New(event IEvent) *Gateway {
	g := new(Gateway)
	g.event = event
	g.conns = map[string]*Connection{}
	g.users = map[string]*Connection{}
	g.groups = map[string]*Group{}
	return g
}

// SendCloseSign requests that this connection be closed. It only raises a
// signal: the Run loop performs the real teardown when it receives from
// closeChan. Only the first call has any effect; later calls are no-ops.
func (c *Connection) SendCloseSign() {
	signal := func() {
		c.closeChan <- 1
		close(c.closeChan)
	}
	c.once.Do(signal)
}

// heartbeat sends a WebSocket ping every 55 seconds until the connection is
// asked to close or a ping fails; a failed ping raises the close signal.
//
// Pings are sent with WriteControl, which gorilla/websocket allows to run
// concurrently with the other write methods, so the heartbeat does not race
// with messages written by the Send* helpers. It always returns nil.
func (c *Connection) heartbeat() error {
	const (
		pingInterval = 55 * time.Second
		pingTimeout  = 10 * time.Second
	)
	ticker := time.NewTicker(pingInterval)
	defer ticker.Stop()

	for {
		select {
		case <-c.closeChan:
			// Close already requested: stop now instead of waiting for the
			// next ping to fail on the closed connection.
			return nil
		case <-ticker.C:
			deadline := time.Now().Add(pingTimeout)
			if err := c.wsConn.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
				c.SendCloseSign()
				return nil
			}
		}
	}
}

// readLoop reads messages from the WebSocket and forwards them to readChan
// for Run. A read error raises the close signal and ends the loop.
//
// The forward also watches closeChan: once Run has stopped draining readChan,
// a full buffer would otherwise block this goroutine forever.
func (c *Connection) readLoop() {
	for {
		messageType, message, err := c.wsConn.ReadMessage()
		if err != nil {
			c.SendCloseSign()
			return
		}
		select {
		case c.readChan <- &WsMessage{Type: messageType, Message: message}:
		case <-c.closeChan:
			return
		}
	}
}

// initConn wraps conn in a Connection identified by clientId and starts its
// heartbeat and read goroutines. readChan buffers up to 1000 messages and
// closeChan has room for the single close signal.
func (p *Gateway) initConn(conn *websocket.Conn, clientId string) *Connection {
	c := &Connection{clientId: clientId, wsConn: conn}
	c.readChan = make(chan *WsMessage, 1000)
	c.closeChan = make(chan byte, 1)
	c.groups = map[string]Nil{}
	c.data = make(map[string]interface{})

	go c.heartbeat()
	go c.readLoop()
	return c
}

// Run upgrades the HTTP request to a WebSocket, registers the connection
// under a fresh UUID client id, fires OnConnect, and then dispatches incoming
// messages to OnMessage until the connection's close signal arrives. At that
// point it calls CloseClient and returns.
func (p *Gateway) Run(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrade.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error;
		// writing a second response here would only trigger a superfluous
		// WriteHeader call.
		return
	}
	clientId := uuid.NewV4().String()
	wsconn := p.initConn(conn, clientId)
	p.mutex.Lock()
	p.conns[clientId] = wsconn
	p.mutex.Unlock()

	p.event.OnConnect(clientId)

	for {
		select {
		case message := <-wsconn.readChan:
			p.event.OnMessage(clientId, message.Message)
		case <-wsconn.closeChan:
			// closeChan is closed right after the signal, so it stays ready
			// forever: return instead of spinning. The error only means the
			// client was already removed by an explicit CloseClient call.
			_ = p.CloseClient(clientId)
			return
		}
	}
}

// getConnByClientId returns the registered connection for clientId, or an
// error when no such client is registered.
func (p *Gateway) getConnByClientId(clientId string) (*Connection, error) {
	p.mutex.RLock()
	conn, found := p.conns[clientId]
	p.mutex.RUnlock()

	if found {
		return conn, nil
	}
	return nil, errors.New("cliendId not found")
}

// getConnByUid returns the connection currently bound to uid, or an error
// when the uid is not bound.
func (p *Gateway) getConnByUid(uid string) (*Connection, error) {
	p.mutex.RLock()
	conn, found := p.users[uid]
	p.mutex.RUnlock()

	if found {
		return conn, nil
	}
	return nil, errors.New("uid not found")
}

// CloseClient closes the connection identified by clientId and performs
// several cleanup steps:
//   - removes it from every group it joined
//   - drops it from the client index
//   - drops it from the uid index
//   - finally fires OnClose
//
// It returns an error if clientId is not (or no longer) registered.
//
// Unregistering happens atomically under the gateway lock, so concurrent
// calls for the same client fire OnClose at most once. Lock order:
// p.mutex, then Group.mutex.
func (p *Gateway) CloseClient(clientId string) error {
	p.mutex.Lock()
	conn, ok := p.conns[clientId]
	if !ok {
		p.mutex.Unlock()
		return errors.New("cliendId not found")
	}
	delete(p.conns, clientId)

	// Only drop the uid binding if it still points at this connection; the
	// uid may meanwhile have been re-bound to a different client.
	if bound, ok := p.users[conn.uid]; ok && bound == conn {
		delete(p.users, conn.uid)
	}

	// Leave every group this connection joined.
	for groupId := range conn.groups {
		if group, ok := p.groups[groupId]; ok {
			group.mutex.Lock()
			delete(group.conns, clientId)
			group.mutex.Unlock()
		}
		delete(conn.groups, groupId)
	}
	p.mutex.Unlock()

	// Best-effort: the socket may already be closed by the peer.
	_ = conn.wsConn.Close()

	p.event.OnClose(clientId)
	return nil
}

// SendToClient sends msg as a text frame to the connection identified by
// clientId. It returns an error if the client is unknown or the write fails.
//
// NOTE(review): gorilla/websocket permits only one concurrent writer per
// connection; concurrent Send* calls that target the same client are not
// serialized here.
func (p *Gateway) SendToClient(clientId string, msg string) error {
	conn, err := p.getConnByClientId(clientId)
	if err != nil {
		return err
	}
	return conn.wsConn.WriteMessage(websocket.TextMessage, []byte(msg))
}

// SendToAll sends msg as a text frame to every registered connection.
// Delivery is best-effort: a failed write to one client is ignored so the
// broadcast still reaches the others.
//
// The recipient list is snapshotted under the read lock so the map is never
// ranged concurrently with Run/CloseClient mutating it, and no lock is held
// while doing network writes.
func (p *Gateway) SendToAll(msg string) {
	p.mutex.RLock()
	targets := make([]*Connection, 0, len(p.conns))
	for _, conn := range p.conns {
		targets = append(targets, conn)
	}
	p.mutex.RUnlock()

	payload := []byte(msg)
	for _, conn := range targets {
		_ = conn.wsConn.WriteMessage(websocket.TextMessage, payload)
	}
}

// IsOnline reports whether clientId is currently registered with the gateway.
func (p *Gateway) IsOnline(clientId string) bool {
	p.mutex.RLock()
	_, online := p.conns[clientId]
	p.mutex.RUnlock()
	return online
}

// BindUid binds uid to the connection identified by clientId.
// It returns an error if uid is already bound or the client is unknown.
// A connection holds at most one uid: binding a new uid releases the
// connection's previous binding.
func (p *Gateway) BindUid(clientId string, uid string) error {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	if _, taken := p.users[uid]; taken {
		return errors.New("uid is bind!")
	}
	conn, ok := p.conns[clientId]
	if !ok {
		return errors.New("client not fount")
	}
	// Drop the old uid's entry so the users index does not keep a stale
	// mapping that nothing would ever clean up.
	if conn.uid != "" && p.users[conn.uid] == conn {
		delete(p.users, conn.uid)
	}
	conn.uid = uid
	p.users[uid] = conn
	return nil
}

// UnBindUid removes the binding of uid, if any. It always returns nil.
func (p *Gateway) UnBindUid(uid string) error {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	if conn, ok := p.users[uid]; ok {
		// Clear the connection's back-reference as well; otherwise
		// GetUidByClientId keeps reporting a uid that is no longer bound.
		if conn.uid == uid {
			conn.uid = ""
		}
		delete(p.users, uid)
	}
	return nil
}

// IsUidOnline reports whether uid is currently bound to a connection.
func (p *Gateway) IsUidOnline(uid string) bool {
	p.mutex.RLock()
	_, bound := p.users[uid]
	p.mutex.RUnlock()
	return bound
}

// GetClientIdByUid returns the client id of the connection bound to uid,
// or an error if uid is not bound.
func (p *Gateway) GetClientIdByUid(uid string) (string, error) {
	p.mutex.RLock()
	defer p.mutex.RUnlock()
	conn, ok := p.users[uid]
	if !ok {
		// Previously reported "uid is bind!", which describes the opposite case.
		return "", errors.New("uid not found")
	}
	return conn.clientId, nil
}

// GetUidByClientId returns the uid bound to the connection identified by
// clientId ("" if none is bound), or an error if the client is unknown.
func (p *Gateway) GetUidByClientId(clientId string) (string, error) {
	p.mutex.RLock()
	defer p.mutex.RUnlock()
	conn, ok := p.conns[clientId]
	if !ok {
		// Previously reported "uid is bind!", which did not describe the failure.
		return "", errors.New("client not found")
	}
	return conn.uid, nil
}

// SendToUid sends msg as a text frame to the connection bound to uid.
// It returns an error if uid is not bound or the write fails.
//
// NOTE(review): gorilla/websocket permits only one concurrent writer per
// connection; concurrent Send* calls that target the same client are not
// serialized here.
func (p *Gateway) SendToUid(uid, msg string) error {
	conn, err := p.getConnByUid(uid)
	if err != nil {
		return err
	}
	return conn.wsConn.WriteMessage(websocket.TextMessage, []byte(msg))
}

// JoinGroup adds the connection identified by clientId to groupId, creating
// the group on first use. It returns an error if the client is unknown or is
// already a member of the group.
//
// The whole lookup-or-create runs under the gateway lock. This prevents two
// concurrent joins of a brand-new group from each creating a Group, with one
// overwriting the other and losing a member. It also keeps conn.groups from
// being mutated concurrently. Lock order: p.mutex, then Group.mutex.
func (p *Gateway) JoinGroup(groupId string, clientId string) error {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	conn, ok := p.conns[clientId]
	if !ok {
		return errors.New("client not fount")
	}
	if _, joined := conn.groups[groupId]; joined {
		return errors.New("client is join group")
	}

	group, ok := p.groups[groupId]
	if !ok {
		group = &Group{conns: map[string]*Connection{}}
		p.groups[groupId] = group
	}

	group.mutex.Lock()
	group.conns[clientId] = conn
	group.mutex.Unlock()

	conn.groups[groupId] = Nil{}
	return nil
}

// LeaveGroup removes the connection identified by clientId from groupId.
// It returns an error if the client or the group is unknown. An empty group
// is kept; use UnGroup to remove it.
//
// The group map is read, and membership is removed, under write locks; the
// previous version deleted from group.conns while holding only a read lock.
// Lock order: p.mutex, then Group.mutex.
func (p *Gateway) LeaveGroup(groupId string, clientId string) error {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	conn, ok := p.conns[clientId]
	if !ok {
		return errors.New("client not fount")
	}
	group, ok := p.groups[groupId]
	if !ok {
		return errors.New("group is not found")
	}

	group.mutex.Lock()
	delete(group.conns, clientId)
	group.mutex.Unlock()

	delete(conn.groups, groupId)
	return nil
}

// SendToGroup sends msg as a text frame to every member of groupId.
// It returns an error if the group does not exist. Every member is attempted
// even when some writes fail, and the first write error is returned.
//
// Members are snapshotted under the group's read lock, so no lock is held
// while doing network writes.
func (p *Gateway) SendToGroup(groupId string, msg string) error {
	p.mutex.RLock()
	group, ok := p.groups[groupId]
	p.mutex.RUnlock()
	if !ok {
		return errors.New("group is not found")
	}

	group.mutex.RLock()
	members := make([]*Connection, 0, len(group.conns))
	for _, conn := range group.conns {
		members = append(members, conn)
	}
	group.mutex.RUnlock()

	payload := []byte(msg)
	var firstErr error
	for _, conn := range members {
		err := conn.wsConn.WriteMessage(websocket.TextMessage, payload)
		if err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

// UnGroup dissolves groupId, removing it from the gateway and from every
// member's list of joined groups. It returns an error if the group does not
// exist. Lock order: p.mutex, then Group.mutex.
func (p *Gateway) UnGroup(groupId string) error {
	// Write lock: the previous version deleted from p.groups under RLock.
	p.mutex.Lock()
	defer p.mutex.Unlock()

	group, ok := p.groups[groupId]
	if !ok {
		return errors.New("group is not found")
	}

	// Forget the membership on each member too; otherwise JoinGroup would
	// keep rejecting them with "client is join group" for a group that no
	// longer exists.
	group.mutex.RLock()
	for _, conn := range group.conns {
		delete(conn.groups, groupId)
	}
	group.mutex.RUnlock()

	delete(p.groups, groupId)
	return nil
}

// GetGroupsByClientId returns the ids of the groups the connection identified
// by clientId has joined, in unspecified order. It returns an error if the
// client is unknown.
func (p *Gateway) GetGroupsByClientId(clientId string) ([]string, error) {
	// Hold the read lock while ranging conn.groups, which group operations
	// modify concurrently.
	p.mutex.RLock()
	defer p.mutex.RUnlock()

	conn, ok := p.conns[clientId]
	if !ok {
		return nil, errors.New("client not fount")
	}
	groups := make([]string, 0, len(conn.groups))
	for groupId := range conn.groups {
		groups = append(groups, groupId)
	}
	return groups, nil
}

// GetClientIdCountByGroup returns the number of connections in groupId, or an
// error if the group does not exist.
func (p *Gateway) GetClientIdCountByGroup(groupId string) (int, error) {
	p.mutex.RLock()
	group, ok := p.groups[groupId]
	p.mutex.RUnlock()
	if !ok {
		return 0, errors.New("group is not found")
	}

	// group.conns is modified under the group's own lock.
	group.mutex.RLock()
	defer group.mutex.RUnlock()
	return len(group.conns), nil
}

// GetAllClientIdCount returns the number of connections currently registered.
func (p *Gateway) GetAllClientIdCount() int {
	p.mutex.RLock()
	total := len(p.conns)
	p.mutex.RUnlock()
	return total
}

// GetClientIdListByGroup returns the client ids of all members of groupId,
// in unspecified order. It returns an error if the group does not exist.
// Lock order: p.mutex, then Group.mutex.
func (p *Gateway) GetClientIdListByGroup(groupId string) ([]string, error) {
	p.mutex.RLock()
	defer p.mutex.RUnlock()
	group, ok := p.groups[groupId]
	if !ok {
		return nil, errors.New("group is not found")
	}

	group.mutex.RLock()
	defer group.mutex.RUnlock()
	clientIds := make([]string, 0, len(group.conns))
	for _, conn := range group.conns {
		clientIds = append(clientIds, conn.clientId)
	}
	return clientIds, nil
}

// GetAllClientIdList returns the client ids of all registered connections, in
// unspecified order. The error is always nil.
func (p *Gateway) GetAllClientIdList() ([]string, error) {
	p.mutex.RLock()
	defer p.mutex.RUnlock()

	ids := make([]string, 0, len(p.conns))
	for _, conn := range p.conns {
		ids = append(ids, conn.clientId)
	}
	return ids, nil
}

// GetUidListByGroup returns the uids bound to members of groupId, in
// unspecified order. Members with no uid bound are skipped, so the result
// only contains real uids. It returns an error if the group does not exist.
// Lock order: p.mutex (guards conn.uid), then Group.mutex.
func (p *Gateway) GetUidListByGroup(groupId string) ([]string, error) {
	p.mutex.RLock()
	defer p.mutex.RUnlock()
	group, ok := p.groups[groupId]
	if !ok {
		return nil, errors.New("group is not found")
	}

	group.mutex.RLock()
	defer group.mutex.RUnlock()
	uids := make([]string, 0, len(group.conns))
	for _, conn := range group.conns {
		if conn.uid == "" {
			// Never bound via BindUid: there is no uid to report.
			continue
		}
		uids = append(uids, conn.uid)
	}
	return uids, nil
}

// GetUidCountByGroup returns how many uids GetUidListByGroup reports for
// groupId, propagating its error (with a zero count) when the group is unknown.
func (p *Gateway) GetUidCountByGroup(groupId string) (count int, err error) {
	var uids []string
	if uids, err = p.GetUidListByGroup(groupId); err == nil {
		count = len(uids)
	}
	return count, err
}

// GetAllUidList returns every uid currently bound to a connection, in
// unspecified order.
func (p *Gateway) GetAllUidList() []string {
	p.mutex.RLock()
	defer p.mutex.RUnlock()

	uids := make([]string, 0, len(p.users))
	for uid := range p.users {
		uids = append(uids, uid)
	}
	return uids
}

// GetAllUidCount returns the number of uids currently bound to a connection.
// (The previous comment described it as returning the uid list.)
func (p *Gateway) GetAllUidCount() int {
	// Count directly instead of materializing the whole list just to take its length.
	p.mutex.RLock()
	defer p.mutex.RUnlock()
	return len(p.users)
}

// GetAllGroupIdList returns the ids of all existing groups, in unspecified order.
func (p *Gateway) GetAllGroupIdList() []string {
	p.mutex.RLock()
	defer p.mutex.RUnlock()

	ids := make([]string, 0, len(p.groups))
	for groupId := range p.groups {
		ids = append(ids, groupId)
	}
	return ids
}

// SetSession replaces the session data of the connection identified by
// clientId with data. It returns an error if the client is unknown.
//
// The map is stored as-is (not copied), so later mutations by the caller are
// visible through GetSession.
func (p *Gateway) SetSession(clientId string, data map[string]interface{}) error {
	// Write lock: conn.data is read concurrently by GetSession and the
	// session-listing helpers.
	p.mutex.Lock()
	defer p.mutex.Unlock()

	conn, ok := p.conns[clientId]
	if !ok {
		return errors.New("client not fount")
	}
	conn.data = data
	return nil
}

// GetSession returns the session data of the connection identified by
// clientId, or an error if the client is unknown.
//
// NOTE(review): the returned map is the stored map itself, not a copy;
// mutating it is not synchronized with other goroutines.
func (p *Gateway) GetSession(clientId string) (map[string]interface{}, error) {
	// Keep the read lock while reading conn.data so it cannot race with SetSession.
	p.mutex.RLock()
	defer p.mutex.RUnlock()

	conn, ok := p.conns[clientId]
	if !ok {
		return nil, errors.New("client not fount")
	}
	return conn.data, nil
}

// GetAllClientSessions returns one single-entry map {clientId: sessionData}
// per registered connection, in unspecified order.
func (p *Gateway) GetAllClientSessions() []map[string]interface{} {
	// The previous version ranged p.conns without any lock, which can crash
	// with a concurrent map read/write while clients connect or disconnect.
	p.mutex.RLock()
	defer p.mutex.RUnlock()

	list := make([]map[string]interface{}, 0, len(p.conns))
	for clientId, conn := range p.conns {
		list = append(list, map[string]interface{}{
			clientId: conn.data,
		})
	}
	return list
}

// GetClientSessionsByGroup returns one single-entry map {clientId: sessionData}
// per member of groupId, in unspecified order. It returns an error if the
// group does not exist. Lock order: p.mutex (guards conn.data), then
// Group.mutex (guards group.conns).
func (p *Gateway) GetClientSessionsByGroup(groupId string) ([]map[string]interface{}, error) {
	p.mutex.RLock()
	defer p.mutex.RUnlock()
	group, ok := p.groups[groupId]
	if !ok {
		return nil, errors.New("group is not found")
	}

	group.mutex.RLock()
	defer group.mutex.RUnlock()
	list := make([]map[string]interface{}, 0, len(group.conns))
	for clientId, conn := range group.conns {
		list = append(list, map[string]interface{}{
			clientId: conn.data,
		})
	}
	return list, nil
}
